# coding=utf-8
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import logging as logger
import pandas as pd
from example.example import read_file
from web.common.config import FILE_PATH, INLINE_TRUE
from web.common.utils import save_gragh_edges, _, update_inline_datasets
from web.models.base_class import DataSetApi
from web.models.task_db import TaskApi

from castle.metrics.evaluation import MetricsDAG


class Evaluation:
    """
    The algorithm evaluation function of the current task is performed by a single task.
    """

    def __init__(self):
        # All evaluation metrics the front end can offer for selection.
        self.evaluation_metrics = ["fdr", "tpr", "fpr", "shd", "nnz",
                                   "precision", "recall", "F1", "gscore"]
        # Refresh the registry of built-in (inline) datasets before reading it.
        update_inline_datasets()
        # Names of the built-in data-generation operators.
        self.generat_operators = INLINE_TRUE

    def get_label_checkbox(self, task_id):
        """Determine whether the current task data is generated internally or imported externally.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        builtin: bool
            True: Data generated by a data generation task in the task list.
            False: Custom data.
        """
        task_api = TaskApi()
        label = task_api.get_label(task_id)
        # bool() guarantees the documented bool return type even when the
        # label is None or an empty string.
        return bool(label and label in self.generat_operators)

    @staticmethod
    def get_task_evaluation_metrics(task_id):
        """Obtain the evaluation metrics already chosen for a task.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        res: list or None
            Metric names stored for the task, or None when no performance
            record exists yet.
        """
        performance = TaskApi().get_performance(task_id)
        return list(performance.keys()) if performance else None

    def get_evaluation_metrics(self, task_id):
        """
        Obtain the selected evaluation metrics and all available metrics of a task.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        : dict
            {"evaluation_list": all supported metrics,
             "chosen_evaluation": metrics already selected (possibly empty)}.
        """
        task_evaluation_metrics = self.get_task_evaluation_metrics(task_id) or []
        return {"evaluation_list": self.evaluation_metrics,
                "chosen_evaluation": task_evaluation_metrics}

    def get_task_builtin_label(self, task_id):
        """
        Obtain the built-in data name used by a task.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        builtin_label: str or None
            Built-in data name, or None when the task uses custom data.
        """
        label = TaskApi().get_label(task_id)
        if label and label in self.generat_operators:
            return label
        return None

    def get_builtin_label(self, task_id):
        """
        Obtain the built-in data name used by a task and the built-in data name list.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        : dict
            {"operators": all built-in names,
             "selected_operators": the one used by this task, or None}.
        """
        return {"operators": self.generat_operators,
                "selected_operators": self.get_task_builtin_label(task_id)}

    def get_task_customize_label(self, task_id):
        """
        Obtain the custom (externally imported) label field of a task.

        Parameters
        ----------
        task_id: int
            task key in the database.

        Returns
        -------
        : dict
            {"label_data_path": path of the custom label data, or None when
            the task uses a built-in dataset}.
        """
        label = TaskApi().get_label(task_id)
        if label and label not in self.generat_operators:
            customize_label = label
        else:
            customize_label = None
        return {"label_data_path": customize_label}

    @staticmethod
    def check_label_dataset(label_path):
        """
        Check whether the label dataset path exists and is non-empty.

        Parameters
        ----------
        label_path: str
            dataset path

        Returns
        -------
        res : bool
            True: path exists and the file is not empty.
            False: path does not exist, or the file is empty.
        """
        return os.path.exists(label_path) and os.path.getsize(label_path) > 0

    def evaluation_execute(self, task_id, label_path, chosen_evaluation):
        """
        Run the evaluation of an estimated DAG against the true DAG.

        Parameters
        ----------
        task_id: int
            task key in the database.
        label_path: str
            Real graph path or built-in data name.
        chosen_evaluation: list
            Selected evaluation indicators.

        Returns
        -------
        : dict
            {"task_id": ..., "evaluations": ...} on success, or
            {"status": 400, "data": <error message>} on failure.
        """
        task_api = TaskApi()
        est_dag = task_api.get_est_dag(task_id)

        try:
            # os.path.join requires str components; task_id is an int key,
            # so convert it explicitly (joining an int raises TypeError).
            task_path = os.path.join(FILE_PATH, 'task', str(task_id))
            file_name = os.path.join(task_path, "true.txt")
            true_dag = read_file(label_path, header=0)
            save_gragh_edges(true_dag, file_name)
            if isinstance(true_dag, pd.DataFrame):
                true_dag = true_dag.values
            task_api.update_true_dag(task_id, true_dag)
            metrics = MetricsDAG(est_dag, true_dag)
        except Exception as error:
            logger.warning(_("evaluation execute failed") + ', exp=%s' % error)
            return {"status": 400, "data": str(error)}

        # Keep only the metrics the user selected.  A dict comprehension is
        # used so the translation helper `_` is not shadowed by a throwaway
        # loop variable.
        evaluation_metrics = {name: value
                              for name, value in metrics.metrics.items()
                              if name in chosen_evaluation}

        # Reuse the TaskApi instance created above instead of building a new one.
        task_api.update_performance(task_id, label_path, evaluation_metrics)
        return {"task_id": task_id,
                "evaluations": evaluation_metrics}